{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TUTORIAL 2: Writing Custom Loss Functions \n",
    "\n",
    "**Authors** - [Avik Pal](https://avik-pal.github.io) &  [Aniket Das](https://aniket1998.github.io)\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/torchgan/torchgan/blob/master/tutorials/Tutorial%202.%20Custom%20Loss%20Functions.ipynb)\n",
    "\n",
    "**TorchGAN** is designed with a high degree of extensibility in mind and lets you write custom loss functions without having to rewrite the entire training and evaluation loop. This is done by extending the ```torchgan.losses.GeneratorLoss``` or the ```torchgan.losses.DiscriminatorLoss``` object.\n",
    "\n",
    "Every **TorchGAN** loss has an associated ```train_ops``` that dictates the steps to be followed for the loss to be computed and backpropagated. By default, most ```train_ops``` follow a **Two Timescale Update Rule (TTUR)** as follows:\n",
    "\n",
    "1. Sample a noise vector from a Normal Distribution $z \\sim \\mathcal{N}(0,\\,1)$\n",
    "2. $d_{real} = D(x)$\n",
    "3. $d_{fake} = D(G(z))$\n",
    "4. $\\mathcal{L} = Loss(d_{real}, d_{fake})$  (*for a Generator Loss $d_{real}$ is generally not computed*)\n",
    "5. Backpropagate over $\\mathcal{L}$\n",
    "\n",
    "where\n",
    "* $x$ is a sample from the Data Distribution\n",
    "* $D$ is the Discriminator Network\n",
    "* $G$ is the Generator Network\n",
    "\n",
    "Simple losses that conform to this kind of update rule can be implemented easily by overriding the ```forward``` method of the ```GeneratorLoss``` or ```DiscriminatorLoss``` object.\n",
    "\n",
    "***NB: It is highly recommended that you go over [Tutorial 1. Introduction to TorchGAN](https://github.com/torchgan/torchgan/blob/master/tutorials/Tutorial%201:%20Introduction%20to%20TorchGAN.ipynb) before reading this.***"
   ]
  },
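  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps listed above can be sketched in plain PyTorch as follows. This is only an illustrative sketch, not TorchGAN's actual ```train_ops``` implementation: the toy networks ```toy_G``` and ```toy_D```, the stand-in batch ```x``` and the binary cross-entropy loss are all assumptions made for the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Hypothetical toy networks, used only to illustrate the order of the steps\n",
    "toy_G = nn.Linear(8, 16)  # maps noise z to a fake sample G(z)\n",
    "toy_D = nn.Linear(16, 1)  # maps a sample to a raw (pre-sigmoid) score\n",
    "\n",
    "x = torch.randn(4, 16)  # stand-in for a batch drawn from the data distribution\n",
    "\n",
    "z = torch.randn(4, 8)  # sample a noise vector z ~ N(0, 1)\n",
    "d_real = toy_D(x)  # d_real = D(x)\n",
    "# d_fake = D(G(z)); detach so that only the discriminator receives gradients here\n",
    "d_fake = toy_D(toy_G(z).detach())\n",
    "\n",
    "# A minimax-style discriminator loss, chosen purely as an example\n",
    "loss = F.binary_cross_entropy_with_logits(\n",
    "    d_real, torch.ones_like(d_real)\n",
    ") + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))\n",
    "\n",
    "loss.backward()  # backpropagate over the loss\n",
    "print(loss.item())"
   ]
  },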
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial assumes that your system has **PyTorch** and **TorchGAN** installed properly. If TorchGAN is missing, the following code block will try to install its **latest tagged version**. If you need some other version, head over to the installation instructions on the [official documentation website](https://torchgan.readthedocs.io/en/latest/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import torchgan\n",
    "\n",
    "    print(f\"Existing TorchGAN {torchgan.__version__} installation found\")\n",
    "except ImportError:\n",
    "    import subprocess\n",
    "    import sys\n",
    "\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"torchgan\"])\n",
    "    import torchgan\n",
    "\n",
    "    print(f\"Installed TorchGAN {torchgan.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## WRITING CUSTOM LOSSES THAT FOLLOW THE STANDARD UPDATE RULE\n",
    "\n",
    "We shall demonstrate this by implementing the [Boundary Seeking GAN by Hjelm et al.](https://arxiv.org/abs/1702.08431), also known as BGAN.\n",
    "BGAN departs from the Minimax Loss by changing the Generator Loss term as follows:\n",
    "\n",
    "$$ \\mathcal{L}_{generator} = \\frac{1}{2}\\mathbb{E}_{z \\sim p(z)}[(\\log(D(G(z))) - \\log(1 - D(G(z))))^2]$$\n",
    "\n",
    "where\n",
    "\n",
    "* $z$ is the noise sampled from a probability distribution\n",
    "* $D$ is the Discriminator Network\n",
    "* $G$ is the Generator Network\n",
    "\n",
    "We can observe that the update rule for such a loss conforms to the Standard Update Rule used in **TorchGAN**. Hence this loss can be implemented simply by extending the ```torchgan.losses.GeneratorLoss``` object and overriding the ```forward``` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# General Imports\n",
    "import os\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Pytorch and Torchvision Imports\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.optim import Adam\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "\n",
    "# Torchgan Imports\n",
    "import torchgan\n",
    "from torchgan.models import DCGANGenerator, DCGANDiscriminator\n",
    "from torchgan.losses import GeneratorLoss, MinimaxDiscriminatorLoss\n",
    "from torchgan.trainer import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seed for reproducibility\n",
    "manualSeed = 999\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "print(\"Random Seed: \", manualSeed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LOADING THE DATASET\n",
    "\n",
    "We make the following transforms before feeding the **MNIST Dataset** into the networks:\n",
    "\n",
    "1. The default size of MNIST is $1 \\times 28 \\times 28$. However, by convention, the default input size in **torchgan.models** is a power of 2 and at least 16. Hence we resize the images to $1 \\times 32 \\times 32$. One can also **zero-pad** the boundary instead, without any noticeable difference.\n",
    "\n",
    "2. The output quality of GANs improves when the images are constrained to the range (-1, 1). The images are therefore normalized with a mean and standard deviation of **0.5**, thereby constraining the inputs to this range.\n",
    "\n",
    "Finally, the **torchgan.trainer.Trainer** needs a **DataLoader** as input, so we construct a DataLoader for the MNIST Dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dsets.MNIST(\n",
    "    root=\"./mnist\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize((32, 32)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=(0.5,), std=(0.5,)),\n",
    "        ]\n",
    "    ),\n",
    ")\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset, shuffle=True, batch_size=512, num_workers=8\n",
    ")"
   ]
  },
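  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the normalization described above, the following sketch applies the same arithmetic as ```transforms.Normalize(mean=(0.5,), std=(0.5,))```, namely $(x - 0.5) / 0.5$, to a few pixel values. The tensor ```pixels``` is a made-up example and is not taken from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Made-up pixel intensities in [0, 1], the range produced by ToTensor\n",
    "pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])\n",
    "\n",
    "# Same arithmetic that Normalize applies with mean=0.5 and std=0.5\n",
    "normalized = (pixels - 0.5) / 0.5\n",
    "print(normalized)"
   ]
  },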
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DEFINING THE BOUNDARY SEEKING LOSS\n",
    "\n",
    "As discussed above, the Boundary Seeking Loss is implemented by overriding the ```forward``` method of the Generator Loss, without requiring any modifications to the ```train_ops```.\n",
    "\n",
    "The ```forward``` method receives the object $d_{fake} = D(G(z))$ as a parameter, where $G$ is the Generator Network, $D$ is the Discriminator Network and $z$ is a sample from the Noise Prior.\n",
    "\n",
    "\n",
    "*NB: This example uses the standard DCGAN Generator and Discriminator available in ```torchgan.models```. By default, the last layer of the discriminator does not apply a Sigmoid nonlinearity, the reason for which has already been discussed in the **Introduction to TorchGAN** tutorial. As a result, the nonlinearity is applied within the loss by a call to ```torch.sigmoid```. Alternatively, one can omit this call and set the ```last_nonlinearity``` property of the DCGAN Discriminator to ```torch.nn.Sigmoid```.*"
   ]
  },
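  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BoundarySeekingLoss(GeneratorLoss):\n",
    "    def forward(self, dgz):\n",
    "        # dgz holds the raw discriminator outputs D(G(z)); map them to probabilities\n",
    "        dgz = torch.sigmoid(dgz)\n",
    "        # 0.5 * E[(log D(G(z)) - log(1 - D(G(z))))^2]\n",
    "        return 0.5 * torch.mean((torch.log(dgz) - torch.log(1.0 - dgz)) ** 2)"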
   ]
  },
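  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because $\\log(\\sigma(a)) - \\log(1 - \\sigma(a)) = a$ for a raw discriminator output $a$, the loss above is mathematically equal to half the mean squared raw output. The sketch below checks this numerically on made-up values. The tensor ```logits``` is a hypothetical stand-in for the raw outputs $D(G(z))$, and the check is an illustration rather than part of TorchGAN. Working directly on the raw outputs may also avoid taking the logarithm of a saturated sigmoid when the outputs have large magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical raw (pre-sigmoid) discriminator outputs for four fake samples\n",
    "logits = torch.tensor([-3.0, -0.5, 0.0, 2.0], dtype=torch.float64)\n",
    "\n",
    "# Loss computed from probabilities, as in the BoundarySeekingLoss defined above\n",
    "probs = torch.sigmoid(logits)\n",
    "via_probs = 0.5 * torch.mean((torch.log(probs) - torch.log(1.0 - probs)) ** 2)\n",
    "\n",
    "# Equivalent form computed directly from the raw outputs\n",
    "via_logits = 0.5 * torch.mean(logits ** 2)\n",
    "\n",
    "print(via_probs.item(), via_logits.item())"
   ]
  },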
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As per the paper, only the Generator Loss is modified. Hence we will use one of the predefined losses, **MinimaxDiscriminatorLoss** for the Discriminator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses = [BoundarySeekingLoss(), MinimaxDiscriminatorLoss()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MODEL CONFIGURATION\n",
    "\n",
    "We shall now define the neural networks for the discriminator and generator and also set up their respective optimizers. For understanding how to do this, please refer to the previous set of tutorials.\n",
    "\n",
    "Note that the Discriminator keeps its default configuration, without a final **nn.Sigmoid**, because ```BoundarySeekingLoss``` applies ```torch.sigmoid``` to the discriminator output itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "network_config = {\n",
    "    \"generator\": {\n",
    "        \"name\": DCGANGenerator,\n",
    "        \"args\": {\"out_channels\": 1, \"step_channels\": 8},\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "    \"discriminator\": {\n",
    "        \"name\": DCGANDiscriminator,\n",
    "        \"args\": {\"in_channels\": 1, \"step_channels\": 8},\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*NB: Training the models is quite expensive, so we train them for only **10** epochs. We recommend using the **GPU runtime** in Colab. The generated images will be very noisy after only **10** epochs, so if you have access to powerful GPUs or want to see realistic samples, we recommend increasing the **epochs** variable (to around 300) in the next code block.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    # Use deterministic cudnn algorithms\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "epochs = 10\n",
    "\n",
    "print(\"Device: {}\".format(device))\n",
    "print(\"Epochs: {}\".format(epochs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VISUALIZE THE TRAINING DATA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot some of the training images\n",
    "real_batch = next(iter(dataloader))\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Training Images\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(\n",
    "            real_batch[0].to(device)[:64], padding=2, normalize=True\n",
    "        ).cpu(),\n",
    "        (1, 2, 0),\n",
    "    )\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TRAINING BGAN\n",
    "\n",
    "Now we shall start the training. First we need to create the **Trainer** object. When creating this object all the necessary neural nets and their optimizers get instantiated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    network_config, losses, ncritic=5, epochs=epochs, sample_size=64, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer(dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VISUALIZING THE GENERATED IMAGES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab a batch of real images from the dataloader\n",
    "real_batch = next(iter(dataloader))\n",
    "\n",
    "# Plot the real images\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Real Images\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(\n",
    "            real_batch[0].to(device)[:64], padding=5, normalize=True\n",
    "        ).cpu(),\n",
    "        (1, 2, 0),\n",
    "    )\n",
    ")\n",
    "\n",
    "# Plot the fake images from the last epoch\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Fake Images\")\n",
    "plt.imshow(plt.imread(\"{}/epoch{}_generator.png\".format(trainer.recon, epochs)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
